{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "from keras.layers.core import Dense\n",
    "from keras.layers.convolutional import Conv2D\n",
    "from keras.layers.wrappers import TimeDistributed\n",
    "from keras import backend as K\n",
    "import json\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    return [round(x * 10**places) / 10**places for x in arr]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Collects the generated fixtures keyed by test-case name; an OrderedDict keeps\n",
    "# insertion order so the exported JSON lists the cases in the order they are added\n",
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TimeDistributed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**[wrappers.TimeDistributed.0] wrap a Dense layer with 4 units (input: 3x6)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W shape: (6, 4)\n",
      "W: [0.317596, 0.688515, -0.688309, -0.48247, 0.387223, -0.718263, 0.281673, -0.106311, 0.576861, -0.083926, 0.631691, 0.92647, 0.579655, -0.024215, -0.805793, -0.842947, -0.955415, 0.656415, 0.44667, 0.633739, 0.701525, 0.917507, -0.185671, -0.105247]\n",
      "b shape: (4,)\n",
      "b: [-0.332867, 0.650317, 0.995501, -0.458367]\n",
      "\n",
      "in shape: (3, 6)\n",
      "in: [-0.30351, 0.37881, -0.248093, 0.372204, -0.698964, -0.408058, -0.103801, 0.376217, -0.724015, 0.708616, -0.513219, -0.46074, -0.125163, -0.76111, -0.153798, 0.729255, 0.556458, -0.671966]\n",
      "out shape: (3, 4)\n",
      "out: [0.171595, -0.652137, 0.618031, -1.295817, -0.05994, -0.407387, 0.000875, -1.993142, -1.33639, 0.854801, 0.555804, -0.650907]\n"
     ]
    }
   ],
   "source": [
    "# Fixture wrappers.TimeDistributed.0: Dense(4) applied via TimeDistributed to each\n",
    "# of the 3 timesteps of a (3, 6) input (saved out shape is (3, 4))\n",
    "data_in_shape = (3, 6)\n",
    "\n",
    "layer_0 = Input(shape=data_in_shape)\n",
    "layer_1 = TimeDistributed(Dense(4))(layer_0)\n",
    "model = Model(inputs=layer_0, outputs=layer_1)\n",
    "\n",
    "# set weights to random values in [-1, 1), reseeding with 4000 + i before each\n",
    "# weight tensor so the fixture is reproducible\n",
    "weights = []\n",
    "for i, w in enumerate(model.get_weights()):\n",
    "    np.random.seed(4000 + i)\n",
    "    weights.append(2 * np.random.random(w.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "# assumes get_weights() order is [kernel, bias]; the printed shapes (6, 4) and (4,)\n",
    "# are consistent with this\n",
    "weight_names = ['W', 'b']\n",
    "for w_i, w_name in enumerate(weight_names):\n",
    "    print('{} shape:'.format(w_name), weights[w_i].shape)\n",
    "    print('{}:'.format(w_name), format_decimal(weights[w_i].ravel().tolist()))\n",
    "\n",
    "# no reseed here: data_in continues from the RNG state left by the last weight draw,\n",
    "# so this statement order must be kept to reproduce the saved values\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "# Input(shape=...) excludes the batch axis, so wrap data_in in a batch of one and\n",
    "# take result[0] to drop it again\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out_shape = result[0].shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(result[0].ravel().tolist())\n",
    "print('')\n",
    "print('in shape:', data_in_shape)\n",
    "print('in:', data_in_formatted)\n",
    "print('out shape:', data_out_shape)\n",
    "print('out:', data_out_formatted)\n",
    "\n",
    "# register the fixture (flattened input, weights and expected output, each with its\n",
    "# shape) for the JSON export\n",
    "DATA['wrappers.TimeDistributed.0'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**[wrappers.TimeDistributed.1] wrap a Conv2D layer with six 3x3 filters (input: 5x4x4x2)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W shape: (3, 3, 2, 6)\n",
      "W: [0.971827, -0.898904, -0.987921, 0.529589, 0.043586, -0.541366, 0.316759, 0.351387, -0.292323, 0.445466, -0.922655, 0.437413, -0.483267, -0.478014, 0.7408, -0.595028, -0.718381, 0.349594, -0.091293, 0.14291, 0.633818, -0.686841, -0.925272, -0.740397, 0.070594, 0.67408, 0.455314, -0.402251, 0.288807, 0.001378, 0.42892, -0.251869, 0.06113, -0.703784, 0.002676, 0.965023, 0.758788, 0.1193, 0.749321, -0.017408, -0.004115, 0.18981, -0.91507, 0.132792, -0.219057, 0.19682, -0.512841, 0.954544, 0.794403, -0.663179, -0.05377, -0.855038, -0.486641, 0.625844, -0.945869, -0.474979, 0.922345, -0.334843, -0.469456, -0.394364, 0.543681, -0.817676, 0.6093, -0.77635, -0.508683, 0.22456, 0.696262, 0.079806, -0.182646, -0.718939, 0.962504, -0.386231, 0.860488, -0.918945, -0.800484, -0.590285, 0.409804, -0.822098, 0.3489, -0.4508, 0.913208, -0.414455, 0.97663, 0.956314, -0.55547, 0.594094, -0.552044, -0.137467, 0.539049, -0.320055, -0.335577, 0.974746, -0.634747, 0.085161, -0.127183, -0.061717, -0.411844, 0.774181, 0.223395, 0.163937, -0.606967, 0.178549, -0.005153, 0.452476, 0.373127, -0.726827, -0.395458, -0.769671]\n",
      "b shape: (6,)\n",
      "b: [0.180389, 0.629217, -0.656262, -0.476575, -0.36398, 0.987756]\n",
      "\n",
      "in shape: (5, 4, 4, 2)\n",
      "in: [-0.579677, 0.883193, 0.651172, -0.820251, -0.64795, 0.857328, -0.4689, 0.356044, -0.641528, -0.531973, -0.33586, -0.438823, 0.682186, 0.215781, -0.401735, 0.169171, 0.869358, -0.204078, -0.661876, -0.616139, -0.453943, -0.569439, -0.25218, 0.156473, 0.194797, -0.923921, 0.652204, -0.11765, 0.86293, 0.314218, -0.878496, -0.364761, -0.647821, 0.296841, 0.280105, 0.2753, -0.959741, -0.148037, -0.489424, -0.88939, 0.704443, 0.08354, 0.930112, -0.87023, -0.212285, 0.750133, 0.343506, -0.82568, 0.391491, 0.149626, 0.003594, -0.181464, -0.499632, 0.20694, 0.1007, 0.39826, 0.609736, -0.765775, -0.728474, -0.011711, 0.543543, 0.174309, 0.105794, -0.009876, -0.694421, -0.157031, 0.670853, -0.581331, 0.739486, -0.886014, -0.637039, 0.725753, 0.61919, 0.447635, 0.167298, 0.164242, -0.615436, -0.503061, 0.981698, -0.392795, 0.532215, 0.761817, 0.735562, -0.236234, -0.856381, 0.22419, -0.221125, 0.133757, -0.011162, -0.88018, -0.433047, -0.825617, 0.693626, -0.185243, -0.824829, 0.07932, 0.336478, 0.370138, -0.685905, -0.462037, 0.563862, 0.490274, 0.934239, -0.129323, 0.717792, -0.73658, -0.939587, 0.796637, -0.131382, -0.79957, -0.271279, 0.816961, -0.082096, 0.64553, -0.106661, 0.651369, -0.843208, -0.221077, 0.758074, 0.156006, -0.429501, 0.191698, 0.988067, -0.277344, 0.757645, -0.877824, 0.053841, 0.394075, 0.786359, 0.735302, 0.247852, -0.310899, 0.703408, -0.848404, 0.455067, 0.295289, -0.629316, 0.626332, -0.075289, -0.442735, -0.219408, -0.766048, 0.303257, 0.142211, 0.910002, -0.780858, 0.333242, -0.533434, 0.572575, 0.355883, -0.671924, 0.22028, -0.505951, -0.317892, 0.609641, -0.360548, 0.490007, 0.441024, 0.660294, 0.850007]\n",
      "out shape: (5, 2, 2, 6)\n",
      "out: [2.089554, -2.186939, -1.436176, -0.951733, -0.212962, 2.449681, 1.053569, -0.592297, -0.875753, -0.803289, -0.834779, -0.568349, -0.842922, 3.976765, -1.054281, 0.581773, 0.235047, 0.103039, -0.079684, 0.225164, -2.408352, -1.116154, 1.561833, -0.491674, 2.43274, -0.158393, -0.874487, -1.96851, -0.106465, 1.602375, 0.941225, 0.480547, 0.002478, 1.246195, -1.388929, -1.133004, 1.476556, -0.459852, -2.130519, -0.126113, -1.162246, 1.398016, -0.61384, 1.539333, -0.466156, 0.0395, 0.506595, -1.590958, -1.044266, 0.736233, 0.61792, -0.923799, 1.275832, 1.491487, 1.903216, -2.385963, -1.553725, -0.554848, -0.456638, 1.645426, 0.690056, 0.190637, -2.015925, 1.143469, -2.530136, 1.025159, -0.150503, 2.627801, -1.352068, 1.245647, 1.235627, -0.915363, 0.682646, 0.854592, -0.030856, 0.949627, 1.204568, 1.052329, -0.942961, 2.039314, 0.892454, -1.925232, 0.046332, 2.315713, -2.358422, 1.724373, -1.528506, 1.794933, 0.342617, -0.191888, -0.026605, 0.475714, -1.332559, -1.158213, 0.028725, 1.890396, -0.305622, 0.890336, -3.426138, 1.245994, -2.027975, -0.505022, 1.32001, 0.477823, -2.460816, -0.984189, 1.221664, 0.339475, 1.26535, 2.228118, 0.207158, -0.455112, -0.64988, 0.688864, 0.574933, 1.911587, -1.642423, -1.385077, 0.744757, -0.567276]\n"
     ]
    }
   ],
   "source": [
    "# Fixture wrappers.TimeDistributed.1: Conv2D(6, (3,3)) applied via TimeDistributed to\n",
    "# each of the 5 timesteps (4x4x2, channels_last) of the input (saved out shape is (5, 2, 2, 6))\n",
    "data_in_shape = (5, 4, 4, 2)\n",
    "\n",
    "layer_0 = Input(shape=data_in_shape)\n",
    "layer_1 = TimeDistributed(Conv2D(6, (3,3), data_format='channels_last'))(layer_0)\n",
    "model = Model(inputs=layer_0, outputs=layer_1)\n",
    "\n",
    "# set weights to random values in [-1, 1), reseeding with 4010 + i before each\n",
    "# weight tensor so the fixture is reproducible\n",
    "weights = []\n",
    "for i, w in enumerate(model.get_weights()):\n",
    "    np.random.seed(4010 + i)\n",
    "    weights.append(2 * np.random.random(w.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "# assumes get_weights() order is [kernel, bias]; the printed shapes (3, 3, 2, 6) and\n",
    "# (6,) are consistent with this\n",
    "weight_names = ['W', 'b']\n",
    "for w_i, w_name in enumerate(weight_names):\n",
    "    print('{} shape:'.format(w_name), weights[w_i].shape)\n",
    "    print('{}:'.format(w_name), format_decimal(weights[w_i].ravel().tolist()))\n",
    "\n",
    "# no reseed here: data_in continues from the RNG state left by the last weight draw,\n",
    "# so this statement order must be kept to reproduce the saved values\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "# Input(shape=...) excludes the batch axis, so wrap data_in in a batch of one and\n",
    "# take result[0] to drop it again\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out_shape = result[0].shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(result[0].ravel().tolist())\n",
    "print('')\n",
    "print('in shape:', data_in_shape)\n",
    "print('in:', data_in_formatted)\n",
    "print('out shape:', data_out_shape)\n",
    "print('out:', data_out_formatted)\n",
    "\n",
    "# register the fixture (flattened input, weights and expected output, each with its\n",
    "# shape) for the JSON export\n",
    "DATA['wrappers.TimeDistributed.1'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"wrappers.TimeDistributed.0\": {\"expected\": {\"data\": [0.171595, -0.652137, 0.618031, -1.295817, -0.05994, -0.407387, 0.000875, -1.993142, -1.33639, 0.854801, 0.555804, -0.650907], \"shape\": [3, 4]}, \"input\": {\"data\": [-0.30351, 0.37881, -0.248093, 0.372204, -0.698964, -0.408058, -0.103801, 0.376217, -0.724015, 0.708616, -0.513219, -0.46074, -0.125163, -0.76111, -0.153798, 0.729255, 0.556458, -0.671966], \"shape\": [3, 6]}, \"weights\": [{\"data\": [0.317596, 0.688515, -0.688309, -0.48247, 0.387223, -0.718263, 0.281673, -0.106311, 0.576861, -0.083926, 0.631691, 0.92647, 0.579655, -0.024215, -0.805793, -0.842947, -0.955415, 0.656415, 0.44667, 0.633739, 0.701525, 0.917507, -0.185671, -0.105247], \"shape\": [6, 4]}, {\"data\": [-0.332867, 0.650317, 0.995501, -0.458367], \"shape\": [4]}]}, \"wrappers.TimeDistributed.1\": {\"expected\": {\"data\": [2.089554, -2.186939, -1.436176, -0.951733, -0.212962, 2.449681, 1.053569, -0.592297, -0.875753, -0.803289, -0.834779, -0.568349, -0.842922, 3.976765, -1.054281, 0.581773, 0.235047, 0.103039, -0.079684, 0.225164, -2.408352, -1.116154, 1.561833, -0.491674, 2.43274, -0.158393, -0.874487, -1.96851, -0.106465, 1.602375, 0.941225, 0.480547, 0.002478, 1.246195, -1.388929, -1.133004, 1.476556, -0.459852, -2.130519, -0.126113, -1.162246, 1.398016, -0.61384, 1.539333, -0.466156, 0.0395, 0.506595, -1.590958, -1.044266, 0.736233, 0.61792, -0.923799, 1.275832, 1.491487, 1.903216, -2.385963, -1.553725, -0.554848, -0.456638, 1.645426, 0.690056, 0.190637, -2.015925, 1.143469, -2.530136, 1.025159, -0.150503, 2.627801, -1.352068, 1.245647, 1.235627, -0.915363, 0.682646, 0.854592, -0.030856, 0.949627, 1.204568, 1.052329, -0.942961, 2.039314, 0.892454, -1.925232, 0.046332, 2.315713, -2.358422, 1.724373, -1.528506, 1.794933, 0.342617, -0.191888, -0.026605, 0.475714, -1.332559, -1.158213, 0.028725, 1.890396, -0.305622, 0.890336, -3.426138, 1.245994, -2.027975, -0.505022, 1.32001, 0.477823, -2.460816, -0.984189, 1.221664, 0.339475, 
1.26535, 2.228118, 0.207158, -0.455112, -0.64988, 0.688864, 0.574933, 1.911587, -1.642423, -1.385077, 0.744757, -0.567276], \"shape\": [5, 2, 2, 6]}, \"input\": {\"data\": [-0.579677, 0.883193, 0.651172, -0.820251, -0.64795, 0.857328, -0.4689, 0.356044, -0.641528, -0.531973, -0.33586, -0.438823, 0.682186, 0.215781, -0.401735, 0.169171, 0.869358, -0.204078, -0.661876, -0.616139, -0.453943, -0.569439, -0.25218, 0.156473, 0.194797, -0.923921, 0.652204, -0.11765, 0.86293, 0.314218, -0.878496, -0.364761, -0.647821, 0.296841, 0.280105, 0.2753, -0.959741, -0.148037, -0.489424, -0.88939, 0.704443, 0.08354, 0.930112, -0.87023, -0.212285, 0.750133, 0.343506, -0.82568, 0.391491, 0.149626, 0.003594, -0.181464, -0.499632, 0.20694, 0.1007, 0.39826, 0.609736, -0.765775, -0.728474, -0.011711, 0.543543, 0.174309, 0.105794, -0.009876, -0.694421, -0.157031, 0.670853, -0.581331, 0.739486, -0.886014, -0.637039, 0.725753, 0.61919, 0.447635, 0.167298, 0.164242, -0.615436, -0.503061, 0.981698, -0.392795, 0.532215, 0.761817, 0.735562, -0.236234, -0.856381, 0.22419, -0.221125, 0.133757, -0.011162, -0.88018, -0.433047, -0.825617, 0.693626, -0.185243, -0.824829, 0.07932, 0.336478, 0.370138, -0.685905, -0.462037, 0.563862, 0.490274, 0.934239, -0.129323, 0.717792, -0.73658, -0.939587, 0.796637, -0.131382, -0.79957, -0.271279, 0.816961, -0.082096, 0.64553, -0.106661, 0.651369, -0.843208, -0.221077, 0.758074, 0.156006, -0.429501, 0.191698, 0.988067, -0.277344, 0.757645, -0.877824, 0.053841, 0.394075, 0.786359, 0.735302, 0.247852, -0.310899, 0.703408, -0.848404, 0.455067, 0.295289, -0.629316, 0.626332, -0.075289, -0.442735, -0.219408, -0.766048, 0.303257, 0.142211, 0.910002, -0.780858, 0.333242, -0.533434, 0.572575, 0.355883, -0.671924, 0.22028, -0.505951, -0.317892, 0.609641, -0.360548, 0.490007, 0.441024, 0.660294, 0.850007], \"shape\": [5, 4, 4, 2]}, \"weights\": [{\"data\": [0.971827, -0.898904, -0.987921, 0.529589, 0.043586, -0.541366, 0.316759, 0.351387, -0.292323, 0.445466, -0.922655, 
0.437413, -0.483267, -0.478014, 0.7408, -0.595028, -0.718381, 0.349594, -0.091293, 0.14291, 0.633818, -0.686841, -0.925272, -0.740397, 0.070594, 0.67408, 0.455314, -0.402251, 0.288807, 0.001378, 0.42892, -0.251869, 0.06113, -0.703784, 0.002676, 0.965023, 0.758788, 0.1193, 0.749321, -0.017408, -0.004115, 0.18981, -0.91507, 0.132792, -0.219057, 0.19682, -0.512841, 0.954544, 0.794403, -0.663179, -0.05377, -0.855038, -0.486641, 0.625844, -0.945869, -0.474979, 0.922345, -0.334843, -0.469456, -0.394364, 0.543681, -0.817676, 0.6093, -0.77635, -0.508683, 0.22456, 0.696262, 0.079806, -0.182646, -0.718939, 0.962504, -0.386231, 0.860488, -0.918945, -0.800484, -0.590285, 0.409804, -0.822098, 0.3489, -0.4508, 0.913208, -0.414455, 0.97663, 0.956314, -0.55547, 0.594094, -0.552044, -0.137467, 0.539049, -0.320055, -0.335577, 0.974746, -0.634747, 0.085161, -0.127183, -0.061717, -0.411844, 0.774181, 0.223395, 0.163937, -0.606967, 0.178549, -0.005153, 0.452476, 0.373127, -0.726827, -0.395458, -0.769671], \"shape\": [3, 3, 2, 6]}, {\"data\": [0.180389, 0.629217, -0.656262, -0.476575, -0.36398, 0.987756], \"shape\": [6]}]}}\n"
     ]
    }
   ],
   "source": [
    "# Emit every collected fixture as one JSON object for the Keras.js tests; print() is\n",
    "# used so the raw JSON text is shown rather than an escaped string repr\n",
    "print(json.dumps(DATA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
